import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np

def init(module, weight_init, bias_init, gain=1):
    """Initialise a module's weight and (if present) bias in place.

    Args:
        module: Layer exposing a ``weight`` parameter and optionally a
            ``bias`` parameter (e.g. ``nn.Linear``).
        weight_init: Callable invoked as ``weight_init(tensor, gain=gain)``
            on the weight data, e.g. ``nn.init.orthogonal_``.
        bias_init: Callable invoked as ``bias_init(tensor)`` on the bias data.
        gain: Scaling factor forwarded to ``weight_init``.

    Returns:
        The same ``module`` object, so the call can wrap a constructor inline.
    """
    weight_init(module.weight.data, gain=gain)
    # Layers created with ``bias=False`` have ``bias is None``; skip them
    # instead of failing on ``None.data``.
    bias = getattr(module, "bias", None)
    if bias is not None:
        bias_init(bias.data)
    return module

class MinecraftNet(nn.Module):
    """Goal-conditioned MLP critic producing one bounded value per action.

    The input is the observation feature vector concatenated with a one-hot
    encoding of an integer goal index. It passes through a 4-layer ReLU MLP
    whose sigmoid output is scaled by ``-H``. Each output therefore lies in
    ``[-H, 0]`` when ``H >= 0``. Presumably this is a Q-value bounded by the
    horizon ``H``; TODO confirm against the training code.

    Args:
        goal_dim: Highest goal index. Goals are one-hot encoded with
            ``goal_dim + 1`` classes, so valid indices are ``0..goal_dim``.
        action_dim: Number of outputs (one per action).
        H: Scale applied to the negated sigmoid output.
        hidden_size: Width of each hidden layer.
        obs_dim: Width of the observation feature vector ``x``. It defaults
            to 96, the value previously hard-coded here.
    """

    def __init__(self, goal_dim, action_dim, H, hidden_size=256, obs_dim=96):
        super().__init__()
        self.goal_dim = goal_dim
        self.action_dim = action_dim
        self.H = H
        self.hidden_size = hidden_size
        self.obs_dim = obs_dim

        # Layer order/indices are kept fixed so state_dict keys
        # (critic.0, critic.2, critic.4, critic.6) stay checkpoint-compatible.
        self.critic = nn.Sequential(
            nn.Linear(obs_dim + goal_dim + 1, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, action_dim),
            nn.Sigmoid(),
        )

    def forward(self, x, goal):
        """Return ``-H * sigmoid(MLP([x, one_hot(goal)]))``.

        Args:
            x: Observation features of shape ``(B, obs_dim)``.
            goal: Integer goal indices in ``[0, goal_dim]``, with exactly one
                per batch row (e.g. shape ``(B,)`` or ``(B, 1)``).

        Returns:
            Tensor of shape ``(B, action_dim)``.
        """
        batch = x.shape[0]
        num_goals = self.goal_dim + 1
        # view() flattens a possible trailing singleton dim from (B, 1) goals.
        goal_one_hot = F.one_hot(goal.long(), num_classes=num_goals).view(batch, num_goals).float()
        features = torch.cat([x, goal_one_hot], dim=1)
        return -self.critic(features) * self.H
